{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "PyTorch/XLA Profiling Colab Tutorial",
      "provenance": [],
      "collapsed_sections": [],
      "machine_shape": "hm"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "TPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YX1hxqUQn47M"
      },
      "source": [
        "## PyTorch/XLA TPU Profiling Colab tutorial\n",
        "\n",
        "*Note*: Since we're not using GCS in this tutorial, TPU side traces won't be collected. To collect full TPU traces follow [this tutorial](https://cloud.google.com/tpu/docs/pytorch-xla-performance-profiling-tpu-vm)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pLQPoJ6Fn8wF"
      },
      "source": [
        "### [RUNME] Install Colab compatible PyTorch/XLA wheels and dependencies\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "O53lrJMDn9Rd"
      },
      "source": [
        "!pip install cloud-tpu-client==0.10 torch==1.9.0 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.9-cp37-cp37m-linux_x86_64.whl tensorboard-plugin-profile"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rroH9yiAn-XE"
      },
      "source": [
        "### Define Parameters\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "iMdPRFXIn_jH"
      },
      "source": [
        "# Run configuration read by the training and tracing cells below.\n",
        "import os\n",
        "\n",
        "FLAGS = dict(\n",
        "    data_dir=\"/tmp/cifar\",\n",
        "    batch_size=128,\n",
        "    num_workers=4,\n",
        "    learning_rate=0.02,\n",
        "    momentum=0.9,\n",
        "    num_epochs=200,\n",
        "    # 8 cores when the TPU_NAME env var is set, otherwise a single core.\n",
        "    num_cores=8 if os.environ.get('TPU_NAME', None) else 1,\n",
        "    log_steps=20,\n",
        "    metrics_debug=False,\n",
        ")"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "EP5H63aViwJe"
      },
      "source": [
        "# Profiler env var; set here, before the next cell imports torch_xla.\n",
        "os.environ.update(XLA_HLO_DEBUG='1')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Micd3xZvoA-c"
      },
      "source": [
        "import multiprocessing\n",
        "import numpy as np\n",
        "import os\n",
        "import sys\n",
        "import time\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "import torch.optim as optim\n",
        "import torch_xla\n",
        "import torch_xla.core.xla_model as xm\n",
        "import torch_xla.debug.metrics as met\n",
        "import torch_xla.debug.profiler as xp\n",
        "import torch_xla.distributed.parallel_loader as pl\n",
        "import torch_xla.distributed.xla_multiprocessing as xmp\n",
        "import torch_xla.utils.utils as xu\n",
        "import torchvision\n",
        "from torchvision import datasets, transforms\n",
        "\n",
        "class BasicBlock(nn.Module):\n",
        "  \"\"\"Residual block: two 3x3 conv + batch-norm layers plus a shortcut branch.\"\"\"\n",
        "  expansion = 1\n",
        "\n",
        "  def __init__(self, in_planes, planes, stride=1):\n",
        "    super().__init__()\n",
        "    out_planes = self.expansion * planes\n",
        "\n",
        "    self.conv1 = nn.Conv2d(\n",
        "        in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)\n",
        "    self.bn1 = nn.BatchNorm2d(planes)\n",
        "    self.conv2 = nn.Conv2d(\n",
        "        planes, planes, kernel_size=3, stride=1, padding=1, bias=False)\n",
        "    self.bn2 = nn.BatchNorm2d(planes)\n",
        "\n",
        "    # The shortcut is an empty Sequential (identity) unless the stride or the\n",
        "    # channel count changes; then a 1x1 conv + batch norm matches the shapes.\n",
        "    if stride == 1 and in_planes == out_planes:\n",
        "      self.shortcut = nn.Sequential()\n",
        "    else:\n",
        "      self.shortcut = nn.Sequential(\n",
        "          nn.Conv2d(\n",
        "              in_planes, out_planes, kernel_size=1, stride=stride, bias=False),\n",
        "          nn.BatchNorm2d(out_planes))\n",
        "\n",
        "  def forward(self, x):\n",
        "    out = self.conv1(x)\n",
        "    out = F.relu(self.bn1(out))\n",
        "    out = self.conv2(out)\n",
        "    out = self.bn2(out)\n",
        "    out += self.shortcut(x)\n",
        "    return F.relu(out)\n",
        "\n",
        "\n",
        "class ResNet(nn.Module):\n",
        "  \"\"\"Conv stem, four stages of residual blocks, 4x4 average pool and a linear classifier.\n",
        "\n",
        "  Returns log-probabilities (log_softmax over dim 1).\n",
        "  \"\"\"\n",
        "\n",
        "  def __init__(self, block, num_blocks, num_classes=10):\n",
        "    super().__init__()\n",
        "    # Channel count fed to the next block; updated by _make_layer.\n",
        "    self.in_planes = 64\n",
        "\n",
        "    self.conv1 = nn.Conv2d(\n",
        "        3, 64, kernel_size=3, stride=1, padding=1, bias=False)\n",
        "    self.bn1 = nn.BatchNorm2d(64)\n",
        "    self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)\n",
        "    self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)\n",
        "    self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)\n",
        "    self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)\n",
        "    self.linear = nn.Linear(512 * block.expansion, num_classes)\n",
        "\n",
        "  def _make_layer(self, block, planes, num_blocks, stride):\n",
        "    \"\"\"Build one stage; the first block uses `stride`, the remaining ones stride 1.\"\"\"\n",
        "    stage = []\n",
        "    for block_stride in [stride] + [1] * (num_blocks - 1):\n",
        "      stage.append(block(self.in_planes, planes, block_stride))\n",
        "      self.in_planes = planes * block.expansion\n",
        "    return nn.Sequential(*stage)\n",
        "\n",
        "  def forward(self, x):\n",
        "    out = F.relu(self.bn1(self.conv1(x)))\n",
        "    for stage in (self.layer1, self.layer2, self.layer3, self.layer4):\n",
        "      out = stage(out)\n",
        "    out = F.avg_pool2d(out, 4)\n",
        "    out = torch.flatten(out, 1)\n",
        "    logits = self.linear(out)\n",
        "    return F.log_softmax(logits, dim=1)\n",
        "\n",
        "\n",
        "def ResNet18():\n",
        "  \"\"\"ResNet made of BasicBlocks, two per stage.\"\"\"\n",
        "  return ResNet(block=BasicBlock, num_blocks=[2, 2, 2, 2])"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LyBPVi71h7ug"
      },
      "source": [
        "In the following cell we define the training loops and, most importantly, add the tracing annotations `xp.StepTrace` and `xp.Trace` that we'll be able to inspect in the profiler's trace view on TensorBoard. `xp.StepTrace` should be used only once per step: it denotes a full step, is used to calculate the step time for the model, and is displayed on the TensorBoard profile summary page. The `xp.Trace` context manager annotation can be sprinkled around whichever parts you want a more detailed timeline of."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "8vMl96KLoCq8"
      },
      "source": [
        "# NOTE(review): SERIAL_EXEC is not referenced by any other cell in this notebook.\n",
        "SERIAL_EXEC = xmp.MpSerialExecutor()\n",
        "# Only instantiate model weights once in memory.\n",
        "WRAPPED_MODEL = xmp.MpModelWrapper(ResNet18())\n",
        "\n",
        "def train_resnet18(training_started):\n",
        "  \"\"\"Run the profiled train/eval loop for FLAGS['num_epochs'] epochs on generated all-zero batches.\n",
        "\n",
        "  Also starts a profiler server on port 9012 in the calling process.\n",
        "\n",
        "  Args:\n",
        "    training_started: object with a `set()` method (a multiprocessing.Event in\n",
        "      this notebook); it is set when a training epoch reaches step index 5.\n",
        "\n",
        "  Returns:\n",
        "    (accuracy, data, pred, target) from the last evaluation pass: accuracy as\n",
        "    a percentage, plus the inputs, predicted class indices and labels of the\n",
        "    final evaluated batch.\n",
        "  \"\"\"\n",
        "  torch.manual_seed(1)\n",
        "\n",
        "  # We are using fake data here (not real CIFAR dataset).\n",
        "  train_dataset_len = 50000  # Number of example in CIFAR train set.\n",
        "  train_loader = xu.SampleGenerator(\n",
        "      data=(torch.zeros(FLAGS['batch_size'], 3, 32,\n",
        "                        32), torch.zeros(FLAGS['batch_size'],\n",
        "                                          dtype=torch.int64)),\n",
        "      sample_count=train_dataset_len // FLAGS['batch_size'] //\n",
        "      xm.xrt_world_size())\n",
        "  test_loader = xu.SampleGenerator(\n",
        "      data=(torch.zeros(FLAGS['batch_size'], 3, 32,\n",
        "                        32), torch.zeros(FLAGS['batch_size'],\n",
        "                                          dtype=torch.int64)),\n",
        "      sample_count=10000 // FLAGS['batch_size'] // xm.xrt_world_size())\n",
        "\n",
        "  # Get loss function, optimizer, and model\n",
        "  device = xm.xla_device()\n",
        "  model = WRAPPED_MODEL.to(device)\n",
        "  optimizer = optim.SGD(model.parameters(), lr=FLAGS['learning_rate'],\n",
        "                        momentum=FLAGS['momentum'], weight_decay=5e-4)\n",
        "  # NLLLoss pairs with the log_softmax output of the model.\n",
        "  loss_fn = nn.NLLLoss()\n",
        "\n",
        "  # Profiler server that the tracing cell connects to via localhost:9012.\n",
        "  # The handle is kept in `server` (presumably so the server stays alive for\n",
        "  # the duration of training -- TODO confirm).\n",
        "  server = xp.start_server(9012)\n",
        "\n",
        "  def train_loop_fn(loader):\n",
        "    \"\"\"Run one training epoch over `loader`, wrapping every step in profiler trace scopes.\"\"\"\n",
        "    tracker = xm.RateTracker()\n",
        "    model.train()\n",
        "    for x, (data, target) in enumerate(loader):\n",
        "      # Signal the waiting notebook cell once step index 5 is reached\n",
        "      # (this runs in every epoch, not only the first).\n",
        "      if x == 5:\n",
        "        training_started.set()\n",
        "      # Let's now profile the training step.\n",
        "      with xp.StepTrace('train_loop', step_num=x):\n",
        "        # This profiles the construction of the graph.\n",
        "        with xp.Trace('build_graph'):\n",
        "          optimizer.zero_grad()\n",
        "          output = model(data)\n",
        "          loss = loss_fn(output, target)\n",
        "          loss.backward()\n",
        "\n",
        "        xm.optimizer_step(optimizer)\n",
        "        tracker.add(FLAGS['batch_size'])\n",
        "        # Log every FLAGS['log_steps'] steps; still inside the StepTrace scope.\n",
        "        if x % FLAGS['log_steps'] == 0:\n",
        "          print('[xla:{}]({}) Loss={:.5f} Rate={:.2f} GlobalRate={:.2f} Time={}'.format(\n",
        "              xm.get_ordinal(), x, loss.item(), tracker.rate(),\n",
        "              tracker.global_rate(), time.asctime()), flush=True)\n",
        "\n",
        "  def test_loop_fn(loader):\n",
        "    \"\"\"Evaluate on `loader`; return (accuracy %, last data, last pred, last target).\"\"\"\n",
        "    total_samples = 0\n",
        "    correct = 0\n",
        "    model.eval()\n",
        "    data, pred, target = None, None, None\n",
        "    for data, target in loader:\n",
        "      output = model(data)\n",
        "      # Index of the highest score along dim 1 is the predicted class.\n",
        "      pred = output.max(1, keepdim=True)[1]\n",
        "      correct += pred.eq(target.view_as(pred)).sum().item()\n",
        "      total_samples += data.size()[0]\n",
        "\n",
        "    # NOTE(review): raises ZeroDivisionError if `loader` yields no batches.\n",
        "    accuracy = 100.0 * correct / total_samples\n",
        "    print('[xla:{}] Accuracy={:.2f}%'.format(\n",
        "        xm.get_ordinal(), accuracy), flush=True)\n",
        "    return accuracy, data, pred, target\n",
        "\n",
        "  # Train and eval loops\n",
        "  accuracy = 0.0\n",
        "  data, pred, target = None, None, None\n",
        "  for epoch in range(1, FLAGS['num_epochs'] + 1):\n",
        "    # A fresh ParallelLoader is built for each training and evaluation pass.\n",
        "    para_loader = pl.ParallelLoader(train_loader, [device])\n",
        "    train_loop_fn(para_loader.per_device_loader(device))\n",
        "    xm.master_print(\"Finished training epoch {}\".format(epoch))\n",
        "\n",
        "    para_loader = pl.ParallelLoader(test_loader, [device])\n",
        "    accuracy, data, pred, target  = test_loop_fn(para_loader.per_device_loader(device))\n",
        "    if FLAGS['metrics_debug']:\n",
        "      xm.master_print(met.metrics_report(), flush=True)\n",
        "\n",
        "  return accuracy, data, pred, target\n",
        "  "
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "_2nL4HmloEyl"
      },
      "source": [
        "# Start training processes\n",
        "def _mp_fn(rank, flags, training_started):\n",
        "  \"\"\"Per-process entry point handed to `xmp.spawn`.\n",
        "\n",
        "  Args:\n",
        "    rank: process index passed in by `xmp.spawn`.\n",
        "    flags: the FLAGS dict of the launching process; re-installed as the\n",
        "      module-level FLAGS so `train_resnet18` can read it.\n",
        "    training_started: event forwarded to `train_resnet18`.\n",
        "  \"\"\"\n",
        "  global FLAGS\n",
        "  FLAGS = flags\n",
        "  torch.set_default_tensor_type('torch.FloatTensor')\n",
        "  accuracy, _, _, _ = train_resnet18(training_started)\n",
        "  if rank == 0:\n",
        "    # This notebook trains on generated all-zero batches and defines no\n",
        "    # plotting helper, so report the final accuracy instead of plotting\n",
        "    # (the previous `plot_results(...)` call raised NameError).\n",
        "    print('[xla:{}] Final accuracy={:.2f}%'.format(rank, accuracy), flush=True)\n",
        "\n",
        "def target_fn(training_started):\n",
        "  \"\"\"Body of the background process: redirect output to log files, then spawn training.\n",
        "\n",
        "  stdout and stderr are replaced with `training_logs.stdout` and\n",
        "  `training_logs.stderr` (in the current working directory) for the lifetime\n",
        "  of this process, so training output goes to those files.\n",
        "  \"\"\"\n",
        "  sys.stdout = open('training_logs.stdout', 'w')\n",
        "  sys.stderr = open('training_logs.stderr', 'w')\n",
        "  xmp.spawn(_mp_fn, args=(FLAGS, training_started,),\n",
        "            nprocs=FLAGS['num_cores'], start_method='fork')\n",
        "  \n",
        "# Training runs in a separate process so the notebook stays free to capture\n",
        "# profiler traces; `training_started` is set once training steps are running.\n",
        "training_started = multiprocessing.Event()\n",
        "p = multiprocessing.Process(target=target_fn, args=(training_started,))\n",
        "p.start()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bcS58faWifX7"
      },
      "source": [
        "The following cell first waits for the training to start up and then subsequently traces both the client VM side (i.e., where the XLA graph is built and input pipeline is run) and the TPU device side (where the actual compilation and execution happens). However, note that since we're running on Colab and not using GCS in this tutorial, TPU side traces won't be collected. To collect full TPU traces follow this [tutorial](https://cloud.google.com/tpu/docs/pytorch-xla-performance-profiling-tpu-vm)."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "zppUkQI3fv2p"
      },
      "source": [
        "import re\n",
        "\n",
        "# Wait up to 120 seconds for the training processes to signal that steps are\n",
        "# running; `wait` returns False on timeout, in which case there is nothing\n",
        "# useful to trace.\n",
        "if not training_started.wait(120):\n",
        "  raise RuntimeError(\n",
        "      'Training did not start within 120 seconds; '\n",
        "      'check training_logs.stdout and training_logs.stderr for errors.')\n",
        "\n",
        "xp.trace('localhost:9012', '/tmp/tensorboard')  # client side profiling\n",
        "\n",
        "# TPU_NAME is expected to look like grpc://<ip>:<4-digit port>. It may be\n",
        "# unset (the parameters cell falls back to a single core in that case), so\n",
        "# only request the TPU side trace when an address can be extracted.\n",
        "tpu_match = re.match(r'grpc://((\\d{1,3}\\.){3}\\d{1,3}):\\d{4}',\n",
        "                     os.environ.get('TPU_NAME', ''))\n",
        "if tpu_match:\n",
        "  tpu_ip = tpu_match.group(1)\n",
        "  xp.trace(f'{tpu_ip}:8466', '/tmp/tensorboard')  # need GCS bucket for all traces to be written\n",
        "else:\n",
        "  print('TPU_NAME is not set to a grpc://<ip>:<port> address; skipping TPU side trace.')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Efdo1gx4bYRY"
      },
      "source": [
        "%load_ext tensorboard\n",
        "%tensorboard --logdir /tmp/tensorboard --load_fast=false\n",
        "# Click on \"INACTIVE\" dropdown and select \"PROFILE\""
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}
